﻿using System.Security.Cryptography;
using System.Xml;

namespace ExcelDataReader.Core.OfficeCrypto;

/// <summary>
/// Represents "Agile Encryption" used in XLSX files (Office 2010 and newer).
/// The encryption parameters are read from the XML encryption descriptor; key derivation
/// and password verification follow the MS-OFFCRYPTO agile scheme.
/// </summary>
internal sealed class AgileEncryption : EncryptionInfo
{
    private const string NEncryption = "encryption";
    private const string NKeyData = "keyData";
    private const string NKeyEncryptors = "keyEncryptors";
    private const string NKeyEncryptor = "keyEncryptor";
    private const string NEncryptedKey = "encryptedKey";
    private const string NsEncryption = "http://schemas.microsoft.com/office/2006/encryption";
    private const string NsPassword = "http://schemas.microsoft.com/office/2006/keyEncryptor/password";

    // Filler byte used when a derived hash is shorter than the required key/IV length
    // (MS-OFFCRYPTO pads with 0x36, not with zeros).
    private const byte PadByte = 0x36;

    // Block keys appended to the iterated password hash to derive the keys that decrypt
    // encryptedVerifierHashInput, encryptedVerifierHashValue and encryptedKeyValue respectively.
    private static readonly byte[] _inputSuffix = [0xfe, 0xa7, 0xd2, 0x76, 0x3b, 0x4b, 0x9e, 0x79];
    private static readonly byte[] _valueSuffix = [0xd7, 0xaa, 0x0f, 0x6d, 0x30, 0x61, 0x34, 0x4e];
    private static readonly byte[] _keyValueSuffix = [0x14, 0x6e, 0x0b, 0xe7, 0xab, 0xac, 0xd0, 0xd6];

    /// <summary>
    /// Initializes a new instance of the <see cref="AgileEncryption"/> class.
    /// </summary>
    /// <param name="bytes">
    /// The raw encryption info stream. The first 8 bytes are skipped (presumably the
    /// version/reserved header — confirm against the caller); the remainder is parsed as XML.
    /// </param>
    public AgileEncryption(byte[] bytes)
    {
        using var stream = new MemoryStream(bytes, 8, bytes.Length - 8);
        using var xmlReader = XmlReader.Create(stream);

        ReadXmlEncryptionInfoStream(xmlReader);
    }

    /// <summary>Gets or sets the package data cipher (keyData/@cipherAlgorithm).</summary>
    public CipherIdentifier CipherAlgorithm { get; set; }

    /// <summary>Gets or sets the package data chaining mode (keyData/@cipherChaining).</summary>
    public CipherMode CipherChaining { get; set; }

    /// <summary>Gets or sets the package data hash algorithm (keyData/@hashAlgorithm).</summary>
    public HashIdentifier HashAlgorithm { get; set; }

    /// <summary>Gets or sets the package data key size in bits (keyData/@keyBits).</summary>
    public int KeyBits { get; set; }

    /// <summary>Gets or sets the package data cipher block size in bytes (keyData/@blockSize).</summary>
    public int BlockSize { get; set; }

    /// <summary>Gets or sets the package data hash size in bytes (keyData/@hashSize).</summary>
    public int HashSize { get; set; }

    /// <summary>Gets or sets the package data salt (keyData/@saltValue, base64-decoded).</summary>
    public byte[] SaltValue { get; set; }

    /// <summary>Gets or sets the password key encryptor salt (encryptedKey/@saltValue).</summary>
    public byte[] PasswordSaltValue { get; set; }

    /// <summary>Gets or sets the password key encryptor cipher (encryptedKey/@cipherAlgorithm).</summary>
    public CipherIdentifier PasswordCipherAlgorithm { get; set; }

    /// <summary>Gets or sets the password key encryptor chaining mode (encryptedKey/@cipherChaining).</summary>
    public CipherMode PasswordCipherChaining { get; set; }

    /// <summary>Gets or sets the password key encryptor hash algorithm (encryptedKey/@hashAlgorithm).</summary>
    public HashIdentifier PasswordHashAlgorithm { get; set; }

    /// <summary>Gets or sets the encrypted package key (encryptedKey/@encryptedKeyValue).</summary>
    public byte[] PasswordEncryptedKeyValue { get; set; }

    /// <summary>Gets or sets the encrypted verifier (encryptedKey/@encryptedVerifierHashInput).</summary>
    public byte[] PasswordEncryptedVerifierHashInput { get; set; }

    /// <summary>Gets or sets the encrypted verifier hash (encryptedKey/@encryptedVerifierHashValue).</summary>
    public byte[] PasswordEncryptedVerifierHashValue { get; set; }

    /// <summary>Gets or sets the number of password hash iterations (encryptedKey/@spinCount).</summary>
    public int PasswordSpinCount { get; set; }

    /// <summary>Gets or sets the password key encryptor key size in bits (encryptedKey/@keyBits).</summary>
    public int PasswordKeyBits { get; set; }

    /// <summary>Gets or sets the password key encryptor block size in bytes (encryptedKey/@blockSize).</summary>
    public int PasswordBlockSize { get; set; }

    /// <summary>Gets a value indicating whether this is XOR obfuscation; always false for agile encryption.</summary>
    public override bool IsXor => false;

    /// <summary>Creates the cipher used for the encrypted package data, as described by keyData.</summary>
    /// <returns>A new cipher instance; the caller owns and disposes it.</returns>
    public override SymmetricAlgorithm CreateCipher()
    {
        return CryptoHelpers.CreateCipher(CipherAlgorithm, KeyBits, BlockSize * 8, CipherChaining);
    }

    /// <summary>
    /// Derives a key from <paramref name="password"/> and uses it to decrypt the stored
    /// encryptedKeyValue, yielding the package secret key.
    /// </summary>
    /// <param name="password">The user-supplied password.</param>
    /// <returns>The decrypted secret key, <see cref="PasswordKeyBits"/>/8 bytes long.</returns>
    public override byte[] GenerateSecretKey(string password)
    {
        using var cipher = CryptoHelpers.CreateCipher(PasswordCipherAlgorithm, PasswordKeyBits, PasswordBlockSize * 8, PasswordCipherChaining);

        return GenerateSecretKey(password, PasswordSaltValue, PasswordHashAlgorithm, PasswordEncryptedKeyValue, PasswordSpinCount, PasswordKeyBits, cipher);
    }

    /// <summary>
    /// Computes H(<paramref name="secretKey"/> + little-endian <paramref name="blockNumber"/>)
    /// with <see cref="HashAlgorithm"/>, truncated or padded with 0x36 to <see cref="BlockSize"/> bytes.
    /// </summary>
    /// <param name="blockNumber">Index of the block being processed.</param>
    /// <param name="secretKey">Input prefix for the hash (caller-supplied).</param>
    /// <returns>A <see cref="BlockSize"/>-byte value for the block.</returns>
    public override byte[] GenerateBlockKey(int blockNumber, byte[] secretKey)
    {
        var hash = CryptoHelpers.HashBytes(CryptoHelpers.Combine(secretKey, ToLittleEndianBytes(blockNumber)), HashAlgorithm);
        return FitToLength(hash, BlockSize);
    }

    /// <summary>Wraps <paramref name="stream"/> in a stream that decrypts the agile-encrypted package.</summary>
    /// <param name="stream">The encrypted package stream.</param>
    /// <param name="secretKey">The key returned by <see cref="GenerateSecretKey(string)"/>.</param>
    /// <returns>A decrypting stream.</returns>
    public override Stream CreateEncryptedPackageStream(Stream stream, byte[] secretKey)
    {
        return new AgileEncryptedPackageStream(stream, secretKey, SaltValue, this);
    }

    /// <summary>
    /// Checks <paramref name="password"/> by decrypting the stored verifier and verifier hash
    /// and comparing the hash of the former against the latter.
    /// </summary>
    /// <param name="password">The password to check.</param>
    /// <returns><c>true</c> when the password matches; otherwise <c>false</c>.</returns>
    public override bool VerifyPassword(string password)
    {
        byte[] passwordHash;
        using (var hashAlgorithm = CryptoHelpers.Create(PasswordHashAlgorithm))
        {
            passwordHash = HashPassword(password, PasswordSaltValue, hashAlgorithm, PasswordSpinCount);
        }

        // Both verifier keys must be sized exactly like the key-value key in GenerateSecretKey
        // (truncate, or pad with 0x36); zero padding would reject valid passwords whenever
        // the hash is shorter than the key (e.g. SHA-1 with AES-256).
        var keyLength = PasswordKeyBits / 8;
        var inputBlockKey = DeriveKey(passwordHash, _inputSuffix, PasswordHashAlgorithm, keyLength);
        var valueBlockKey = DeriveKey(passwordHash, _valueSuffix, PasswordHashAlgorithm, keyLength);

        using var cipher = CryptoHelpers.CreateCipher(PasswordCipherAlgorithm, PasswordKeyBits, PasswordBlockSize * 8, PasswordCipherChaining);

        var decryptedVerifier = CryptoHelpers.DecryptBytes(cipher, PasswordEncryptedVerifierHashInput, inputBlockKey, PasswordSaltValue);
        var decryptedVerifierHash = CryptoHelpers.DecryptBytes(cipher, PasswordEncryptedVerifierHashValue, valueBlockKey, PasswordSaltValue);

        var verifierHash = CryptoHelpers.HashBytes(decryptedVerifier, PasswordHashAlgorithm);

        // The decrypted buffer may carry trailing padding beyond the hash length, so only the
        // leading hash-size bytes are significant. A shorter buffer cannot hold a valid hash
        // and must not be treated as a (vacuous) match.
        if (decryptedVerifierHash.Length < verifierHash.Length)
        {
            return false;
        }

        for (var i = 0; i < verifierHash.Length; ++i)
        {
            if (decryptedVerifierHash[i] != verifierHash[i])
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Derives the key-encryption key from the password and decrypts <paramref name="encryptedKeyValue"/>.
    /// </summary>
    /// <returns>The decrypted key, resized to <paramref name="keyBits"/>/8 bytes.</returns>
    private static byte[] GenerateSecretKey(string password, byte[] saltValue, HashIdentifier hashIdentifier, byte[] encryptedKeyValue, int spinCount, int keyBits, SymmetricAlgorithm cipher)
    {
        byte[] passwordHash;
        using (var hashAlgorithm = CryptoHelpers.Create(hashIdentifier))
        {
            passwordHash = HashPassword(password, saltValue, hashAlgorithm, spinCount);
        }

        var keyLength = keyBits / 8;
        var keyEncryptionKey = DeriveKey(passwordHash, _keyValueSuffix, hashIdentifier, keyLength);

        // NOTE: the stored salt is padded to a multiple of the block size which affects AES-192
        var decryptedKeyValue = CryptoHelpers.DecryptBytes(cipher, encryptedKeyValue, keyEncryptionKey, saltValue);
        Array.Resize(ref decryptedKeyValue, keyLength);
        return decryptedKeyValue;
    }

    /// <summary>
    /// Computes H(<paramref name="passwordHash"/> + <paramref name="blockKey"/>) and fits the
    /// result to <paramref name="length"/> bytes (truncate, or pad with 0x36).
    /// </summary>
    private static byte[] DeriveKey(byte[] passwordHash, byte[] blockKey, HashIdentifier hashIdentifier, int length)
    {
        var hash = CryptoHelpers.HashBytes(CryptoHelpers.Combine(passwordHash, blockKey), hashIdentifier);
        return FitToLength(hash, length);
    }

    /// <summary>
    /// Truncates <paramref name="value"/> to <paramref name="length"/> bytes, or extends it with
    /// <see cref="PadByte"/>. May return <paramref name="value"/> itself when it already has that length.
    /// </summary>
    private static byte[] FitToLength(byte[] value, int length)
    {
        var originalLength = value.Length;
        Array.Resize(ref value, length);
        for (var i = originalLength; i < length; i++)
        {
            value[i] = PadByte;
        }

        return value;
    }

    /// <summary>Encodes <paramref name="value"/> as 4 little-endian bytes regardless of platform byte order.</summary>
    private static byte[] ToLittleEndianBytes(int value) =>
        [(byte)value, (byte)(value >> 8), (byte)(value >> 16), (byte)(value >> 24)];

#if NETSTANDARD2_1_OR_GREATER || NET8_0_OR_GREATER
    /// <summary>
    /// Computes H(salt + UTF-16LE password), then re-hashes <paramref name="spinCount"/> times
    /// as H(little-endian iteration index + previous hash).
    /// </summary>
    private static byte[] HashPassword(string password, byte[] saltValue, HashAlgorithm hashAlgorithm, int spinCount)
    {
        var saltAndPassword = CryptoHelpers.Combine(saltValue, System.Text.Encoding.Unicode.GetBytes(password));
        var hash = hashAlgorithm.ComputeHash(saltAndPassword);

        // A single small buffer is reused for every iteration: [index (4 bytes) | previous hash].
        var iterationData = new byte[4 + hash.Length];
        var hashPart = iterationData.AsSpan(4);

        for (var i = 0; i < spinCount; i++)
        {
            System.Buffers.Binary.BinaryPrimitives.WriteInt32LittleEndian(iterationData, i);
            hash.CopyTo(hashPart);

            // The destination is exactly one hash long, so failure here would indicate a broken
            // hash implementation; never continue with a stale hash.
            if (!hashAlgorithm.TryComputeHash(iterationData, hash, out _))
            {
                throw new CryptographicException("Failed to compute the password hash.");
            }
        }

        return hash;
    }
#else
    /// <summary>
    /// Computes H(salt + UTF-16LE password), then re-hashes <paramref name="spinCount"/> times
    /// as H(little-endian iteration index + previous hash).
    /// </summary>
    private static byte[] HashPassword(string password, byte[] saltValue, HashAlgorithm hashAlgorithm, int spinCount)
    {
        var saltAndPassword = CryptoHelpers.Combine(saltValue, System.Text.Encoding.Unicode.GetBytes(password));
        var hash = hashAlgorithm.ComputeHash(saltAndPassword);

        var iterationData = new byte[4 + hash.Length];

        for (var i = 0; i < spinCount; i++)
        {
            // Write the iteration index little-endian, independent of platform byte order.
            iterationData[0] = (byte)i;
            iterationData[1] = (byte)(i >> 8);
            iterationData[2] = (byte)(i >> 16);
            iterationData[3] = (byte)(i >> 24);
            Buffer.BlockCopy(hash, 0, iterationData, 4, hash.Length);
            hash = hashAlgorithm.ComputeHash(iterationData, 0, iterationData.Length);
        }

        return hash;
    }
#endif

    /// <summary>Parses an integer attribute using the invariant culture; missing or malformed values yield 0.</summary>
    private static int ParseInt(string value)
    {
        return int.TryParse(value, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out var result)
            ? result
            : 0;
    }

    /// <summary>Maps a hashAlgorithm attribute value onto <see cref="HashIdentifier"/> (case-sensitive enum name).</summary>
    private static HashIdentifier ParseHash(string value)
    {
#if NETSTANDARD2_1_OR_GREATER || NET8_0_OR_GREATER
        return Enum.Parse<HashIdentifier>(value);
#else
        return (HashIdentifier)Enum.Parse(typeof(HashIdentifier), value);
#endif
    }

    /// <summary>Maps a cipherAlgorithm attribute value onto <see cref="CipherIdentifier"/>.</summary>
    /// <exception cref="ArgumentException">The value is not a supported cipher.</exception>
    private static CipherIdentifier ParseCipher(string value) => value switch
    {
        "AES" => CipherIdentifier.AES,
        "DES" => CipherIdentifier.DES,
        "3DES" => CipherIdentifier.DES3,
        "RC2" => CipherIdentifier.RC2,
        _ => throw new ArgumentException("Unknown encryption: " + value, nameof(value)),
    };

    /// <summary>Maps a cipherChaining attribute value onto <see cref="CipherMode"/>.</summary>
    /// <exception cref="ArgumentException">The value is not CBC or CFB chaining.</exception>
    private static CipherMode ParseCipherMode(string value) => value switch
    {
        "ChainingModeCBC" => CipherMode.CBC,
        "ChainingModeCFB" => CipherMode.CFB,
        _ => throw new ArgumentException("Invalid CipherMode " + value, nameof(value)),
    };

    /// <summary>
    /// Reads the root &lt;encryption&gt; element, populating the keyData properties and
    /// delegating &lt;keyEncryptors&gt; to <see cref="ReadKeyEncryptors"/>. Other children are skipped.
    /// </summary>
    private void ReadXmlEncryptionInfoStream(XmlReader xmlReader)
    {
        if (!xmlReader.IsStartElement(NEncryption, NsEncryption))
        {
            return;
        }

        if (!XmlReaderHelper.ReadFirstContent(xmlReader))
        {
            return;
        }

        while (!xmlReader.EOF)
        {
            if (xmlReader.IsStartElement(NKeyData, NsEncryption))
            {
                // <keyData saltSize="16" blockSize="16" keyBits="256" hashSize="64" cipherAlgorithm="AES" cipherChaining="ChainingModeCBC" hashAlgorithm="SHA512" saltValue="zYmgeIEW4PVmYPiNJItVCQ=="/>
                // <dataIntegrity encryptedHmacKey="v11xCwbBfQ6Wq03h2M6Nh5Z9fwNnFQwEzu8vmBDps55kd+HfLDzrnuKzuQq4tlpxW0nX99VWh+n2X6ukU6v9FQ==" encryptedHmacValue="SvDwFQR4dNsXOzNstFWHqSHpAUWHQvAr63IhxlxhlQEAczDPIwCWD32aIEFipY7NOlW+LvYPaKC8zO1otxit2g=="/>
                // <keyEncryptors><keyEncryptor uri="http://schemas.microsoft.com/office/2006/keyEncryptor/password"><p:encryptedKey spinCount="100000" saltSize="16" blockSize="16" keyBits="256" hashSize="64" cipherAlgorithm="AES" cipherChaining="ChainingModeCBC" hashAlgorithm="SHA512" saltValue="n37HW2mNfJuGwVxTeBY1LA==" encryptedVerifierHashInput="2Y2Oo+QDyMdo327gZUcejA==" encryptedVerifierHashValue="PmkCD5y5cHqMQqbgACUgxLRgISYZL6+jj3K0PSrFDWlEG+fjzFevIee1FubgdpY2P22IIM6W7C/bXE0ayAo8yg==" encryptedKeyValue="qzkvVPIBy2Bk/w2/fp+hhpq5sPReA8aUu414/Xh7494="/></keyEncryptor></keyEncryptors>
                var cipherAlgorithm = xmlReader.GetAttribute("cipherAlgorithm");
                var cipherChaining = xmlReader.GetAttribute("cipherChaining");
                var hashAlgorithm = xmlReader.GetAttribute("hashAlgorithm");
                var saltValue = xmlReader.GetAttribute("saltValue");

                SaltValue = Convert.FromBase64String(saltValue);
                HashSize = ParseInt(xmlReader.GetAttribute("hashSize")); // given in bytes; also implied by the hash algorithm
                KeyBits = ParseInt(xmlReader.GetAttribute("keyBits"));
                BlockSize = ParseInt(xmlReader.GetAttribute("blockSize"));
                CipherAlgorithm = ParseCipher(cipherAlgorithm);
                CipherChaining = ParseCipherMode(cipherChaining);
                HashAlgorithm = ParseHash(hashAlgorithm);
                xmlReader.Skip();
            }
            else if (xmlReader.IsStartElement(NKeyEncryptors, NsEncryption))
            {
                ReadKeyEncryptors(xmlReader);
            }
            else if (!XmlReaderHelper.SkipContent(xmlReader))
            {
                break;
            }
        }
    }

    /// <summary>Reads the children of &lt;keyEncryptors&gt;, handing each &lt;keyEncryptor&gt; to <see cref="ReadKeyEncryptor"/>.</summary>
    private void ReadKeyEncryptors(XmlReader xmlReader)
    {
        if (!XmlReaderHelper.ReadFirstContent(xmlReader))
        {
            return;
        }

        while (!xmlReader.EOF)
        {
            if (xmlReader.IsStartElement(NKeyEncryptor, NsEncryption))
            {
                // <keyEncryptor uri="http://schemas.microsoft.com/office/2006/keyEncryptor/password">
                ReadKeyEncryptor(xmlReader);
            }
            else if (!XmlReaderHelper.SkipContent(xmlReader))
            {
                break;
            }
        }
    }

    /// <summary>
    /// Reads a &lt;keyEncryptor&gt; element; a password &lt;p:encryptedKey&gt; child populates the
    /// Password* properties. Other children are skipped.
    /// </summary>
    private void ReadKeyEncryptor(XmlReader xmlReader)
    {
        if (!XmlReaderHelper.ReadFirstContent(xmlReader))
        {
            return;
        }

        while (!xmlReader.EOF)
        {
            if (xmlReader.IsStartElement(NEncryptedKey, NsPassword))
            {
                // <p:encryptedKey spinCount="100000" saltSize="16" blockSize="16" keyBits="256" hashSize="64" cipherAlgorithm="AES" cipherChaining="ChainingModeCBC" hashAlgorithm="SHA512" saltValue="n37HW2mNfJuGwVxTeBY1LA==" encryptedVerifierHashInput="2Y2Oo+QDyMdo327gZUcejA==" encryptedVerifierHashValue="PmkCD5y5cHqMQqbgACUgxLRgISYZL6+jj3K0PSrFDWlEG+fjzFevIee1FubgdpY2P22IIM6W7C/bXE0ayAo8yg==" encryptedKeyValue="qzkvVPIBy2Bk/w2/fp+hhpq5sPReA8aUu414/Xh7494="/>
                var cipherAlgorithm = xmlReader.GetAttribute("cipherAlgorithm");
                var cipherChaining = xmlReader.GetAttribute("cipherChaining");
                var hashAlgorithm = xmlReader.GetAttribute("hashAlgorithm");
                var saltValue = xmlReader.GetAttribute("saltValue");
                var encryptedVerifierHashInput = xmlReader.GetAttribute("encryptedVerifierHashInput");
                var encryptedVerifierHashValue = xmlReader.GetAttribute("encryptedVerifierHashValue");
                var encryptedKeyValue = xmlReader.GetAttribute("encryptedKeyValue");

                PasswordSaltValue = Convert.FromBase64String(saltValue);
                PasswordCipherAlgorithm = ParseCipher(cipherAlgorithm);
                PasswordCipherChaining = ParseCipherMode(cipherChaining);
                PasswordHashAlgorithm = ParseHash(hashAlgorithm);
                PasswordEncryptedKeyValue = Convert.FromBase64String(encryptedKeyValue);
                PasswordEncryptedVerifierHashInput = Convert.FromBase64String(encryptedVerifierHashInput);
                PasswordEncryptedVerifierHashValue = Convert.FromBase64String(encryptedVerifierHashValue);
                PasswordSpinCount = ParseInt(xmlReader.GetAttribute("spinCount"));
                PasswordKeyBits = ParseInt(xmlReader.GetAttribute("keyBits"));
                PasswordBlockSize = ParseInt(xmlReader.GetAttribute("blockSize"));

                xmlReader.Skip();
            }
            else if (!XmlReaderHelper.SkipContent(xmlReader))
            {
                break;
            }
        }
    }
}
